{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Understanding Tree SHAP for Simple Models\n", "\n", "The SHAP value for a feature is the average change in model output by conditioning on that feature when introducing features one at a time over all feature orderings. While this is easy to state, it is challenging to compute. So this notebook is meant to give a few simple examples where we can see how this plays out for very small trees. For arbitrary large trees it is very hard to intuitively guess these values by looking at the tree." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import graphviz\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.tree import DecisionTreeRegressor, export_graphviz\n", "\n", "import shap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Single split example" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "0\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 100\n", "value = 0.5\n", "\n", "\n", "\n", "1\n", "\n", "squared_error = 0.0\n", "samples = 50\n", "value = 0.0\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "2\n", "\n", "squared_error = 0.0\n", "samples = 50\n", "value = 1.0\n", "\n", "\n", "\n", "0->2\n", "\n", "\n", "False\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build data\n", "N = 100\n", "M = 4\n", "X = np.zeros((N, M))\n", "X.shape\n", "y = np.zeros(N)\n", "X[: N // 2, 0] = 1\n", "y[: N // 2] = 1\n", "\n", "# fit model\n", "single_split_model = DecisionTreeRegressor(max_depth=1)\n", "single_split_model.fit(X, y)\n", "\n", "# draw model\n", "dot_data = export_graphviz(\n", " single_split_model,\n", " out_file=None,\n", " filled=True,\n", " rounded=True,\n", " special_characters=True,\n", ")\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explaining the model\n", "\n", "Note that the bias term is the expected output of the model over the training dataset (0.5). The SHAP value for features not used in the model is always 0, while for $x_0$ it is just the difference between the expected value and the output of the model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4
Example 0x1.01.01.01.0
shap_values0.50.00.00.0
Example 1x0.00.00.00.0
shap_values-0.50.00.00.0
\n", "
" ], "text/plain": [ " x1 x2 x3 x4\n", "Example 0 x 1.0 1.0 1.0 1.0\n", " shap_values 0.5 0.0 0.0 0.0\n", "Example 1 x 0.0 0.0 0.0 0.0\n", " shap_values -0.5 0.0 0.0 0.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = [np.ones(M), np.zeros(M)]\n", "df = pd.DataFrame()\n", "for idx, x in enumerate(xs):\n", " index = pd.MultiIndex.from_product([[f\"Example {idx}\"], [\"x\", \"shap_values\"]])\n", " df = pd.concat(\n", " [\n", " df,\n", " pd.DataFrame(\n", " [x, shap.TreeExplainer(single_split_model).shap_values(x)],\n", " index=index,\n", " columns=[\"x1\", \"x2\", \"x3\", \"x4\"],\n", " ),\n", " ]\n", " )\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two features AND example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use two features in this example. If feature $x_{0} = 1$ AND $x_{1} = 1$, the target value is one, else zero. Hence we call this the AND model." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "0\n", " ≤ 0.5\n", "squared_error = 0.188\n", "samples = 100\n", "value = 0.25\n", "\n", "\n", "\n", "1\n", "\n", "squared_error = 0.0\n", "samples = 50\n", "value = 0.0\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "2\n", "\n", "x\n", "1\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 50\n", "value = 0.5\n", "\n", "\n", "\n", "0->2\n", "\n", "\n", "False\n", "\n", "\n", "\n", "3\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 0.0\n", "\n", "\n", "\n", "2->3\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 1.0\n", "\n", "\n", "\n", "2->4\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build data\n", "N = 100\n", "M = 4\n", "X = np.zeros((N, M))\n", "X.shape\n", "y = np.zeros(N)\n", "X[: 1 * N // 4, 1] = 1\n", "X[: N // 2, 0] = 1\n", "X[N // 2 : 3 * N // 4, 1] = 1\n", "y[: 1 * N // 4] = 1\n", "\n", "# fit model\n", "and_model = DecisionTreeRegressor(max_depth=2)\n", "and_model.fit(X, y)\n", "\n", "# draw model\n", "dot_data = export_graphviz(\n", " and_model, out_file=None, filled=True, rounded=True, special_characters=True\n", ")\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explaining the model\n", "\n", "Note that the bias term is the expected output of the model over the training dataset (0.25). The SHAP values for the unused features $x_2$ and $x_3$ are always 0. For $x_0$ and $x_1$ it is just the difference between the expected value (0.25) and the output of the model split equally between them (since they equally contribute to the AND function)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4
Example 0x1.0001.0001.01.0
shap_values0.3750.3750.00.0
Example 1x0.0000.0000.00.0
shap_values-0.125-0.1250.00.0
\n", "
" ], "text/plain": [ " x1 x2 x3 x4\n", "Example 0 x 1.000 1.000 1.0 1.0\n", " shap_values 0.375 0.375 0.0 0.0\n", "Example 1 x 0.000 0.000 0.0 0.0\n", " shap_values -0.125 -0.125 0.0 0.0" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = np.array([np.ones(M), np.zeros(M)])\n", "# np.array([np.ones(M), np.zeros(M), np.array([1, 0, 1, 0]), np.array([0, 1, 0, 0])] # you can also check these examples\n", "df = pd.DataFrame()\n", "for idx, x in enumerate(xs):\n", " index = pd.MultiIndex.from_product([[f\"Example {idx}\"], [\"x\", \"shap_values\"]])\n", " df = pd.concat(\n", " [\n", " df,\n", " pd.DataFrame(\n", " [x, shap.TreeExplainer(and_model).shap_values(x)],\n", " index=index,\n", " columns=[\"x1\", \"x2\", \"x3\", \"x4\"],\n", " ),\n", " ]\n", " )\n", "df" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.25" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is how you get to the Shap values of Example 1:
\n", "The bias term (`y.mean()`) is 0.25, and the target value is 1. This leaves 1 - 0.27 = 0.75 to split among the relevant features. Since only $x_1$ and $x_2$ contribute to the target value (and to the same extent), it is divided among them, i.e., 0.375 for each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two features OR example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We do a slight variation of the example above. If $x_{0} = 1$ OR $x_{1} = 1$ the target is 1, else 0. Can you guess the SHAP values without scrolling down?" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "1\n", " ≤ 0.5\n", "squared_error = 0.188\n", "samples = 100\n", "value = 0.75\n", "\n", "\n", "\n", "1\n", "\n", "x\n", "0\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 50\n", "value = 0.5\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "4\n", "\n", "squared_error = 0.0\n", "samples = 50\n", "value = 1.0\n", "\n", "\n", "\n", "0->4\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 0.0\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 1.0\n", "\n", "\n", "\n", "1->3\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build data\n", "N = 100\n", "M = 4\n", "X = np.zeros((N, M))\n", "X.shape\n", "y = np.zeros(N)\n", "X[: N // 2, 0] = 1\n", "X[: 1 * N // 4, 1] = 1\n", "X[N // 2 : 3 * N // 4, 1] = 1\n", "y[: N // 2] = 1\n", "y[N // 2 : 3 * N // 4] = 1\n", "\n", "# fit model\n", "or_model = DecisionTreeRegressor(max_depth=2)\n", "or_model.fit(X, y)\n", "\n", "# draw model\n", "dot_data = export_graphviz(\n", " or_model, out_file=None, filled=True, rounded=True, special_characters=True\n", ")\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explaining the model\n", "\n", "Note that the bias term is the expected output of the model over the training dataset (0.75). The SHAP value for features not used in the model is always 0, while for $x_0$ and $x_1$ it is just the difference between the expected value and the output of the model split equally between them (since they equally contribute to the OR function)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4
Example 0x1.0001.0001.01.0
shap_values0.1250.1250.00.0
Example 1x0.0000.0000.00.0
shap_values-0.375-0.3750.00.0
\n", "
" ], "text/plain": [ " x1 x2 x3 x4\n", "Example 0 x 1.000 1.000 1.0 1.0\n", " shap_values 0.125 0.125 0.0 0.0\n", "Example 1 x 0.000 0.000 0.0 0.0\n", " shap_values -0.375 -0.375 0.0 0.0" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = np.array([np.ones(M), np.zeros(M)])\n", "# np.array([np.ones(M), np.zeros(M), np.array([1, 0, 1, 0]), np.array([0, 1, 0, 0])] # you can also check these examples\n", "df = pd.DataFrame()\n", "for idx, x in enumerate(xs):\n", " index = pd.MultiIndex.from_product([[f\"Example {idx}\"], [\"x\", \"shap_values\"]])\n", " df = pd.concat(\n", " [\n", " df,\n", " pd.DataFrame(\n", " [x, shap.TreeExplainer(or_model).shap_values(x)],\n", " index=index,\n", " columns=[\"x1\", \"x2\", \"x3\", \"x4\"],\n", " ),\n", " ]\n", " )\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two feature XOR example" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "0\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 100\n", "value = 0.5\n", "\n", "\n", "\n", "1\n", "\n", "x\n", "1\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 50\n", "value = 0.5\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "4\n", "\n", "x\n", "1\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 50\n", "value = 0.5\n", "\n", "\n", "\n", "0->4\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 0.0\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 1.0\n", "\n", "\n", "\n", "1->3\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 1.0\n", "\n", "\n", "\n", "4->5\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 0.0\n", "\n", "\n", "\n", "4->6\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build data\n", "N = 100\n", "M = 4\n", "X = np.zeros((N, M))\n", "X.shape\n", "y = np.zeros(N)\n", "X[: N // 2, 0] = 1\n", "X[: 1 * N // 4, 1] = 1\n", "X[N // 2 : 3 * N // 4, 1] = 1\n", "y[1 * N // 4 : N // 2] = 1\n", "y[N // 2 : 3 * N // 4] = 1\n", "\n", "# fit model\n", "xor_model = DecisionTreeRegressor(max_depth=2)\n", "xor_model.fit(X, y)\n", "\n", "# draw model\n", "dot_data = export_graphviz(\n", " xor_model, out_file=None, filled=True, rounded=True, special_characters=True\n", ")\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explaining the model\n", "\n", "Note that the bias term is the expected output of the model over the training dataset (0.5). The SHAP value for features not used in the model is always 0, while for $x_0$ and $x_1$ it is just the difference between the expected value and the output of the model split equally between them (since they equally contribute to the XOR function)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4
Example 0x1.001.001.01.0
shap_values-0.25-0.250.00.0
Example 1x0.000.000.00.0
shap_values-0.25-0.250.00.0
\n", "
" ], "text/plain": [ " x1 x2 x3 x4\n", "Example 0 x 1.00 1.00 1.0 1.0\n", " shap_values -0.25 -0.25 0.0 0.0\n", "Example 1 x 0.00 0.00 0.0 0.0\n", " shap_values -0.25 -0.25 0.0 0.0" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = np.array([np.ones(M), np.zeros(M)])\n", "# np.array([np.ones(M), np.zeros(M), np.array([1, 0, 1, 0]), np.array([0, 1, 0, 0])] # you can also check these examples\n", "df = pd.DataFrame()\n", "for idx, x in enumerate(xs):\n", " index = pd.MultiIndex.from_product([[f\"Example {idx}\"], [\"x\", \"shap_values\"]])\n", " df = pd.concat(\n", " [\n", " df,\n", " pd.DataFrame(\n", " [x, shap.TreeExplainer(xor_model).shap_values(x)],\n", " index=index,\n", " columns=[\"x1\", \"x2\", \"x3\", \"x4\"],\n", " ),\n", " ]\n", " )\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two features AND + feature boost example" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "0\n", " ≤ 0.5\n", "squared_error = 0.688\n", "samples = 100\n", "value = 0.75\n", "\n", "\n", "\n", "1\n", "\n", "squared_error = 0.0\n", "samples = 50\n", "value = 0.0\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "2\n", "\n", "x\n", "1\n", " ≤ 0.5\n", "squared_error = 0.25\n", "samples = 50\n", "value = 1.5\n", "\n", "\n", "\n", "0->2\n", "\n", "\n", "False\n", "\n", "\n", "\n", "3\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 1.0\n", "\n", "\n", "\n", "2->3\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "squared_error = 0.0\n", "samples = 25\n", "value = 2.0\n", "\n", "\n", "\n", "2->4\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build data\n", "N = 100\n", "M = 4\n", "X = np.zeros((N, M))\n", "X.shape\n", "y = np.zeros(N)\n", "X[: N // 2, 0] = 1\n", "X[: 1 * N // 4, 1] = 1\n", "X[N // 2 : 3 * N // 4, 1] = 1\n", "y[: 1 * N // 4] = 1\n", "y[: N // 2] += 1\n", "\n", "# fit model\n", "and_fb_model = DecisionTreeRegressor(max_depth=2)\n", "and_fb_model.fit(X, y)\n", "\n", "# draw model\n", "dot_data = export_graphviz(\n", " and_fb_model, out_file=None, filled=True, rounded=True, special_characters=True\n", ")\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explain the model\n", "\n", "Note that the bias term is the expected output of the model over the training dataset (0.75). The SHAP value for features not used in the model is always 0, while for $x_0$ and $x_1$ it is just the difference between the expected value and the output of the model split equally between them (since they equally contribute to the AND function), plus an extra 0.5 impact for $x_0$ since it has an effect of $1.0$ all by itself (+0.5 if it is on and -0.5 if it is off)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2x3x4
Example 0x1.0001.0001.01.0
shap_values0.8750.3750.00.0
Example 1x0.0000.0000.00.0
shap_values-0.625-0.1250.00.0
\n", "
" ], "text/plain": [ " x1 x2 x3 x4\n", "Example 0 x 1.000 1.000 1.0 1.0\n", " shap_values 0.875 0.375 0.0 0.0\n", "Example 1 x 0.000 0.000 0.0 0.0\n", " shap_values -0.625 -0.125 0.0 0.0" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = np.array([np.ones(M), np.zeros(M)])\n", "# np.array([np.ones(M), np.zeros(M), np.array([1, 0, 1, 0]), np.array([0, 1, 0, 0])] # you can also check these examples\n", "df = pd.DataFrame()\n", "for idx, x in enumerate(xs):\n", " index = pd.MultiIndex.from_product([[f\"Example {idx}\"], [\"x\", \"shap_values\"]])\n", " df = pd.concat(\n", " [\n", " df,\n", " pd.DataFrame(\n", " [x, shap.TreeExplainer(and_fb_model).shap_values(x)],\n", " index=index,\n", " columns=[\"x1\", \"x2\", \"x3\", \"x4\"],\n", " ),\n", " ]\n", " )\n", "df" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 1 }